{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discount factor: weight of the next state's Q-value in learn()\n",
    "GAMMA = 0.9\n",
    "# Step size: share of the new estimate blended into the stored Q-value\n",
    "# by learn() (the old value keeps a share of 1 - LEARNING_RATE)\n",
    "LEARNING_RATE = 0.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    # Load the saved Q-table when Qtable.txt is present; otherwise start\n",
    "    # from an all-zero table so that a first run does not crash.\n",
    "    # NOTE(review): pickle.load can execute arbitrary code from the file -\n",
    "    # only load a Qtable.txt that this project wrote itself.\n",
    "    try:\n",
    "        with open(\"Qtable.txt\", \"rb\") as f:\n",
    "            Q_table = pickle.load(f)\n",
    "    except FileNotFoundError:\n",
    "        # 11 rows: state() returns an index in 0..10.\n",
    "        # 4 columns: next_best_action() chooses among actions 0..3.\n",
    "        Q_table = np.zeros((11, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def next_best_action(state: int, Q_table: np.ndarray) -> int:\n",
    "    \"\"\"\n",
    "    Return the greedy action for ``state``, or a random action when the\n",
    "    table holds no values for that state.\n",
    "\n",
    "    ``np.argmax`` never returns ``None``, so a missing or empty row is\n",
    "    detected explicitly before the argmax is taken.\n",
    "\n",
    "    :param state: row index into ``Q_table``.\n",
    "    :param Q_table: per-state action values (array or nested list).\n",
    "    :return: index of the highest-valued action, or a random choice from\n",
    "        0, 1, 2, 3 when ``Q_table`` has no values for ``state``.\n",
    "\n",
    "    >>> next_best_action(1, []) in [0, 1, 2, 3]\n",
    "    True\n",
    "    >>> next_best_action(0, [[0.1, 0.9, 0.3, 0.2]])\n",
    "    1\n",
    "    \"\"\"\n",
    "    n_states = len(Q_table)\n",
    "    # Accept the same index range plain indexing would (negatives included).\n",
    "    if -n_states <= state < n_states:\n",
    "        action_values = np.asarray(Q_table[state])\n",
    "        if action_values.size > 0:\n",
    "            return int(np.argmax(action_values))\n",
    "    return int(np.random.choice([0, 1, 2, 3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def state_action_reward(player: object, x_food: int, y_food: int) -> tuple:\n",
    "    \"\"\"\n",
    "    Collect the (state, action, reward) triple used to update the Q-table.\n",
    "\n",
    "    :param player: the agent; forwarded unchanged to state() and reward().\n",
    "    :param x_food: x coordinate of the food.\n",
    "    :param y_food: y coordinate of the food.\n",
    "    :return: ``(current_state, current_action, current_reward)``.\n",
    "    \"\"\"\n",
    "    current_state = state(player, x_food, y_food)\n",
    "    # Index the table with the computed state number, not with the state()\n",
    "    # function object.\n",
    "    current_action = next_best_action(current_state, Q_table)\n",
    "    current_reward = reward(player, x_food, y_food)\n",
    "    return (current_state, current_action, current_reward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# States to consider:\n",
    "#   * Food's relative positioning wrt the head\n",
    "# Checking for food in the eight directions around the head\n",
    "###\n",
    "# #\n",
    "###\n",
    "#   * obstruction Ahead, Right, Left\n",
    "\n",
    "def state(player: object, x_food: int, y_food: int) -> int:\n",
    "    \"\"\"\n",
    "    Encode the agent's surroundings as one integer in the range 0..10.\n",
    "\n",
    "    Eleven flags are evaluated relative to the current heading, in this\n",
    "    order: danger ahead, danger left, danger right, then food straight\n",
    "    ahead, ahead-right, ahead-left, behind-right, behind-left, directly\n",
    "    behind, left and right. The food flags only count when no danger flag\n",
    "    is set. The result is the index of the last flag that is set, or 0\n",
    "    when none is set.\n",
    "\n",
    "    :param player: object exposing ``body`` (whose first element has\n",
    "        ``pos``), ``dirnx``, ``dirny``, ``delx`` and ``dely``.\n",
    "    :param x_food: x coordinate of the food.\n",
    "    :param y_food: y coordinate of the food.\n",
    "    :return: the encoded state index.\n",
    "    \"\"\"\n",
    "    x_agent, y_agent = player.body[0].pos\n",
    "    offset_x = x_food - x_agent\n",
    "    offset_y = y_food - y_agent\n",
    "\n",
    "    def blocked(cell: tuple, axis: int, towards_max: bool) -> bool:\n",
    "        # A cell is dangerous when it lies beyond the wall being approached\n",
    "        # (the checks treat 0..14 as on the board) or is found in player.body.\n",
    "        # NOTE(review): body[0] exposes .pos, so a bare tuple only matches\n",
    "        # if the body elements compare equal to tuples - confirm.\n",
    "        beyond_wall = cell[axis] > 14 if towards_max else cell[axis] < 0\n",
    "        return beyond_wall or cell in player.body\n",
    "\n",
    "    def sign(value: int) -> int:\n",
    "        return (value > 0) - (value < 0)\n",
    "\n",
    "    # Slot of each food flag, keyed by\n",
    "    # (sign of the offset along the heading, sign of the offset to the right).\n",
    "    food_slot = {\n",
    "        (1, 0): 0,    # straight ahead\n",
    "        (1, 1): 1,    # ahead and right\n",
    "        (1, -1): 2,   # ahead and left\n",
    "        (-1, 1): 3,   # behind and right\n",
    "        (-1, -1): 4,  # behind and left\n",
    "        (-1, 0): 5,   # directly behind\n",
    "        (0, -1): 6,   # left\n",
    "        (0, 1): 7,    # right\n",
    "    }\n",
    "\n",
    "    # (heading is active, unit step forward, unit step to the agent's right)\n",
    "    headings = (\n",
    "        (player.dirnx == 1, (1, 0), (0, 1)),\n",
    "        (player.dirnx == -1, (-1, 0), (0, -1)),\n",
    "        (player.dirny == 1, (0, 1), (-1, 0)),\n",
    "        (player.dirny == -1, (0, -1), (1, 0)),\n",
    "    )\n",
    "\n",
    "    danger = [False, False, False]  # ahead, left, right\n",
    "    food = [False] * 8\n",
    "    for active, (fwd_x, fwd_y), (right_x, right_y) in headings:\n",
    "        if not active:\n",
    "            continue\n",
    "        # The cell ahead uses the player's own step (delx / dely); the side\n",
    "        # cells are always exactly one step away.\n",
    "        if fwd_x:\n",
    "            ahead_cell, axis = (x_agent + player.delx, y_agent), 0\n",
    "        else:\n",
    "            ahead_cell, axis = (x_agent, y_agent + player.dely), 1\n",
    "        side_axis = 1 - axis\n",
    "        right_cell = (x_agent + right_x, y_agent + right_y)\n",
    "        left_cell = (x_agent - right_x, y_agent - right_y)\n",
    "        right_step = right_x + right_y  # +1 or -1 along side_axis\n",
    "\n",
    "        danger[0] = danger[0] or blocked(ahead_cell, axis, fwd_x + fwd_y > 0)\n",
    "        danger[1] = danger[1] or blocked(left_cell, side_axis, right_step < 0)\n",
    "        danger[2] = danger[2] or blocked(right_cell, side_axis, right_step > 0)\n",
    "\n",
    "        along = sign(fwd_x * offset_x + fwd_y * offset_y)\n",
    "        across = sign(right_x * offset_x + right_y * offset_y)\n",
    "        if (along, across) in food_slot:\n",
    "            food[food_slot[(along, across)]] = True\n",
    "\n",
    "    # Danger takes priority: food flags are blanked whenever any danger is set.\n",
    "    flags = danger + ([False] * 8 if any(danger) else food)\n",
    "    return max((index for index, flag in enumerate(flags) if flag), default=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def euler_dist(x1: int, y1: int, x2: int, y2: int) -> float:\n",
    "    \"\"\"\n",
    "    Return the Euclidean (straight-line) distance between two points.\n",
    "\n",
    "    :param x1: x coordinate of the first point.\n",
    "    :param y1: y coordinate of the first point.\n",
    "    :param x2: x coordinate of the second point.\n",
    "    :param y2: y coordinate of the second point.\n",
    "    :return: the distance as a float.\n",
    "\n",
    "    >>> euler_dist(0, 0, 3, 4)\n",
    "    5.0\n",
    "    \"\"\"\n",
    "    squared_dist = (x1 - x2) ** 2 + (y1 - y2) ** 2\n",
    "    return squared_dist ** 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reward conditions:\n",
    "#   * +10 for eating food\n",
    "#   * -12 for dying\n",
    "#   * -2 for getting closer\n",
    "#   * -7 for going away from the fruit\n",
    "\n",
    "\n",
    "def reward(player: object, x_food: int, y_food: int) -> int:\n",
    "    \"\"\"\n",
    "    Return the reward for the agent's current position and next step.\n",
    "\n",
    "    The next cell is the head position plus (player.delx, player.dely).\n",
    "\n",
    "      * +10 when the head is on the food\n",
    "      * -12 when the next cell is a wall or part of the body, or when\n",
    "        player.resetDone is True\n",
    "      * -2 when the next cell is closer to the food than the head is\n",
    "      * -7 when the next cell is further from the food than the head is\n",
    "      * 0 when the distance to the food does not change\n",
    "\n",
    "    :param player: object exposing ``body`` (whose first element has\n",
    "        ``pos``), ``dirnx``, ``dirny``, ``delx``, ``dely`` and ``resetDone``.\n",
    "    :param x_food: x coordinate of the food.\n",
    "    :param y_food: y coordinate of the food.\n",
    "    :return: the reward value.\n",
    "    \"\"\"\n",
    "    x_agent, y_agent = player.body[0].pos\n",
    "    if x_agent == x_food and y_agent == y_food:\n",
    "        return +10\n",
    "\n",
    "    next_x = x_agent + player.delx\n",
    "    next_y = y_agent + player.dely\n",
    "\n",
    "    # NOTE(review): these wall tests use <= 0 / >= 14 while state() uses\n",
    "    # < 0 / > 14 for its danger flags - confirm which bound is intended.\n",
    "    hits_wall = (\n",
    "        (player.dirnx == -1 and next_x <= 0)\n",
    "        or (player.dirnx == 1 and next_x >= 14)\n",
    "        or (player.dirny == 1 and next_y >= 14)\n",
    "        or (player.dirny == -1 and next_y <= 0)\n",
    "    )\n",
    "    # The body test only applies while the agent has a heading.\n",
    "    # NOTE(review): body[0] exposes .pos, so a bare tuple only matches if\n",
    "    # the body elements compare equal to tuples - confirm.\n",
    "    has_heading = player.dirnx in (-1, 1) or player.dirny in (-1, 1)\n",
    "    hits_body = has_heading and (next_x, next_y) in player.body\n",
    "    if hits_wall or hits_body or player.resetDone is True:\n",
    "        return -12\n",
    "\n",
    "    # Negative change: the next cell is closer to the food.\n",
    "    change = euler_dist(next_x, next_y, x_food, y_food) - euler_dist(\n",
    "        x_agent, y_agent, x_food, y_food\n",
    "    )\n",
    "    if change < 0:\n",
    "        return -2\n",
    "    if change > 0:\n",
    "        return -7\n",
    "    # Distance unchanged: return a neutral reward rather than None.\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def learn(\n",
    "        state: int, action: int, reward: int, next_state: int,\n",
    "        next_action: int, i: int, trialNumber: int, epsilon: float) -> None:\n",
    "    \"\"\"\n",
    "    Update Q_table[state][action] in place from one observed transition.\n",
    "\n",
    "    new = (1 - LEARNING_RATE) * old + LEARNING_RATE * (reward + GAMMA * nxt)\n",
    "\n",
    "    where ``old`` is Q_table[state][action] and ``nxt`` is\n",
    "    Q_table[next_state][next_action].\n",
    "\n",
    "    :param state: row of the entry being updated.\n",
    "    :param action: column of the entry being updated.\n",
    "    :param reward: reward observed for the transition.\n",
    "    :param next_state: row of the entry used as the look-ahead value.\n",
    "    :param next_action: column of the entry used as the look-ahead value.\n",
    "    :param i: unused; kept so existing callers keep working.\n",
    "    :param trialNumber: the whole table is printed whenever this is a\n",
    "        multiple of 100.\n",
    "    :param epsilon: unused; kept so existing callers keep working.\n",
    "    \"\"\"\n",
    "    current_q = Q_table[state][action]\n",
    "    next_q = Q_table[next_state][next_action]\n",
    "    Q_table[state][action] = (\n",
    "        (1 - LEARNING_RATE) * current_q\n",
    "        + LEARNING_RATE * (reward + GAMMA * next_q)\n",
    "    )\n",
    "    if trialNumber % 100 == 0:\n",
    "        print(\"Printing Q_table: \")\n",
    "        print(Q_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
